from torch import nn
import torch
from einops.layers.torch import Rearrange


class MaskedLinear(nn.Linear):
    """``nn.Linear`` with a random subset of weights zeroed and marked frozen.

    When ``freezing_rate`` is positive,
    ``round(in_features * out_features * freezing_rate)`` weight entries are
    selected at random (top-k of uniform noise), set to zero, and recorded in
    the boolean tensor ``mask`` (``True`` marks a frozen entry, shape
    ``(out_features, in_features)``).  The layer does not hook into autograd:
    the training loop has to call :meth:`mask_grad` after ``backward()`` so
    the frozen entries receive a zero gradient.

    When ``freezing_rate <= 0`` the layer is a plain ``nn.Linear`` and neither
    ``mask`` nor ``num_to_freeze`` is defined.

    Args:
        in_features, out_features, bias, device, dtype: forwarded unchanged
            to ``nn.Linear``.
        freezing_rate: fraction of weight entries to freeze; values ``<= 0``
            disable masking.

    Raises:
        ValueError: if ``freezing_rate`` asks for more entries than the weight
            matrix holds.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None, freezing_rate: float = -1) -> None:
        super().__init__(in_features=in_features, out_features=out_features, bias=bias, device=device, dtype=dtype)

        self.freezing_rate = freezing_rate
        if self.freezing_rate <= 0:
            # Masking disabled: nothing else to set up.
            return

        num_params = in_features * out_features
        self.num_to_freeze = round(num_params * freezing_rate)
        assert self.num_to_freeze > 0
        if self.num_to_freeze > num_params:
            # Without this check torch.topk fails with an opaque range error.
            raise ValueError(
                f"freezing_rate={freezing_rate} selects {self.num_to_freeze} weights "
                f"but the layer only has {num_params}"
            )

        print(f"freezgin_rate:{self.freezing_rate}")
        print("setting top num_to_freeze values in smask to 1")
        print("the corresponding values in the weight tensor will be set to zero and frozen")

        # The indices of the k largest uniform samples form a random subset of
        # distinct positions in the flattened weight matrix.
        scores = torch.rand((out_features, in_features), device=device)
        frozen_idx = torch.topk(scores.view(-1), self.num_to_freeze).indices
        flat_mask = torch.zeros(num_params, dtype=torch.bool, device=scores.device)
        flat_mask[frozen_idx] = True

        # Registered as a buffer so that .to()/.cuda()/.cpu() move the mask
        # together with the weight; non-persistent so state_dict keys are the
        # same as for a plain nn.Linear.
        self.register_buffer("mask", flat_mask.reshape(out_features, in_features), persistent=False)

        s = torch.sum(self.mask).item()
        p = 100 * s / num_params
        print(f'applying smask to layer  - freezing {s} of {num_params} weights ({p:.4f}%)')
        with torch.no_grad():
            self.weight[self.mask] = 0

    def mask_grad(self):
        """Zero the gradient of every frozen weight entry, in place.

        Does nothing when masking is disabled or when no gradient has been
        computed yet (``weight.grad`` is ``None``).
        """
        if self.freezing_rate <= 0:
            return
        grad = self.weight.grad
        if grad is None:
            # Before the first backward(), or after zero_grad(set_to_none=True).
            return
        with torch.no_grad():
            grad[self.mask] = 0

from .mlpmixer import PreNormResidual


class PaddingFlat(nn.Module):
    """Pad the last dimension of a tensor from ``in_shape[-1]`` to ``out_shape[-1]``.

    The extra entries are split between both ends; when the amount is odd the
    right side receives the additional one.

    Args:
        in_shape: sequence whose last element is the incoming feature size.
        out_shape: sequence whose last element is the target feature size;
            must not be smaller than the incoming one.
        mode: padding mode handed to ``torch.nn.functional.pad``.
        value: fill value handed to ``torch.nn.functional.pad``.
    """

    def __init__(self, in_shape, out_shape, mode="constant", value=0):
        super().__init__()

        extra = out_shape[-1] - in_shape[-1]
        assert extra >= 0

        left = int(extra / 2)
        # (left, right) amounts for the last dimension.
        self.pad = (left, extra - left)
        self.mode = mode
        self.value = value

    def forward(self, x):
        """Return ``x`` padded along its last dimension."""
        padded = nn.functional.pad(x, self.pad, mode=self.mode, value=self.value)
        return padded



def MaskedMLP(*, image_size, channels, width, depth,
              num_classes, expansion_factor=-1, freezing_rate=1, patch_size=-1, use_skip=True,
              mask_device="cuda"):
    """Build an image classifier MLP whose hidden layers are ``MaskedLinear``.

    The returned ``nn.Sequential`` is: an input stage mapping a
    ``(batch, channels, height, width)`` image to a flat vector of ``width``
    features, ``depth`` hidden blocks, and a final ``nn.Linear(width,
    num_classes)``.

    Args:
        image_size: side length (int) or ``(height, width)`` tuple.
        channels: number of image channels.
        width: feature size of the hidden blocks.
        depth: number of hidden blocks.
        num_classes: output size of the final linear layer.
        expansion_factor: if ``<= 0`` each block is a single
            ``MaskedLinear(width, width)`` followed by GELU; otherwise each
            block is two ``MaskedLinear`` layers with an inner size of
            ``round(width * expansion_factor)``, each followed by GELU.
        freezing_rate: forwarded to every ``MaskedLinear`` (fraction of its
            weights that are zeroed and masked).
        patch_size: ``-1`` flattens the whole image and projects it with one
            ``nn.Linear``; a value ``> 1`` cuts the image into patches,
            projects each patch to ``width // num_patches`` features,
            concatenates them and zero-pads the result up to ``width``.
        use_skip: wrap each hidden block in ``PreNormResidual``.
        mask_device: device handed to every ``MaskedLinear``.  The plain
            ``nn.Linear`` layers are created without a device argument, so
            the caller may still need to move the returned model.

    Raises:
        ValueError: if ``patch_size`` is neither ``-1`` nor ``> 1``, or if
            ``width`` is smaller than the number of patches.
    """
    def pair(x):
        return x if isinstance(x, tuple) else (x, x)

    if patch_size == -1:
        # Accept tuple sizes here as well, like the patch branch does.
        image_h, image_w = pair(image_size)
        in_dim = image_h * image_w * channels
        preblock = [
            Rearrange('b c h w -> b (h w c)'),
            nn.Linear(in_dim, width),
            nn.GELU(),
        ]
    elif patch_size > 1:
        image_h, image_w = pair(image_size)
        assert (image_h % patch_size) == 0 and (image_w % patch_size) == 0, 'image must be divisible by patch size'
        num_patches = (image_h // patch_size) * (image_w // patch_size)
        num_channels = (patch_size ** 2) * channels
        dim = width // num_patches
        if dim < 1:
            # A zero-sized projection would discard the image entirely and
            # feed the hidden blocks nothing but padding.
            raise ValueError(
                f"width ({width}) must be at least the number of patches ({num_patches})"
            )
        preblock = [
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
            nn.Linear(num_channels, dim),
            nn.GELU(),
            Rearrange('b s c -> b (s c)'),
        ]
        patch_features = num_patches * dim  # <= width because of the floor division
        if patch_features < width:
            preblock.append(PaddingFlat((patch_features,), (width,)))
    else:
        raise ValueError(
            f"patch_size must be -1 (no patching) or greater than 1, got {patch_size}"
        )

    # Both input stages end with exactly `width` features.
    pre_dim = width

    blocks = []
    for ell in range(depth):
        i_width = pre_dim if ell == 0 else width
        if expansion_factor <= 0:
            tmp_block = [
                MaskedLinear(i_width, width, freezing_rate=freezing_rate, device=mask_device),
                nn.GELU(),
            ]
        else:
            hidden_dim = round(width * expansion_factor)
            assert hidden_dim > 0
            tmp_block = [
                MaskedLinear(i_width, hidden_dim, freezing_rate=freezing_rate, device=mask_device),
                nn.GELU(),
                MaskedLinear(hidden_dim, width, freezing_rate=freezing_rate, device=mask_device),
                nn.GELU(),
            ]

        if use_skip:
            blocks.append(PreNormResidual(dim=i_width, fn=nn.Sequential(*tmp_block)))
        else:
            blocks.extend(tmp_block)

    return nn.Sequential(*preblock, *blocks, nn.Linear(width, num_classes))

def HiddenMaskedMLP(width,  depth,
                  num_classes, expansion_factor=-1,freezing_rate=1, patch_size=-1,use_skip=False, mask_device="cuda"):
    """Build a stack of ``MaskedLinear`` blocks without an input projection.

    The model starts with ``nn.Flatten`` and feeds the result straight into
    the first ``MaskedLinear``, so the flattened input must already have
    ``width`` features.  It ends with ``nn.Linear(width, num_classes)``.

    Args:
        width: feature size of every block's input and output.
        depth: number of hidden blocks.
        num_classes: output size of the final linear layer.
        expansion_factor: if ``<= 0`` a block is one ``MaskedLinear`` plus
            GELU; otherwise two ``MaskedLinear`` layers with an inner size of
            ``round(width * expansion_factor)``, each followed by GELU.
        freezing_rate: forwarded to every ``MaskedLinear``.
        patch_size: accepted but not used by this function.
        use_skip: wrap each block in ``PreNormResidual``.
        mask_device: device handed to every ``MaskedLinear``.
    """
    def masked(n_in, n_out):
        # All masked layers share the same freezing rate and device.
        return MaskedLinear(n_in, n_out, freezing_rate=freezing_rate, device=mask_device)

    layers = [nn.Flatten()]
    for _ in range(depth):
        if expansion_factor <= 0:
            stage = [masked(width, width), nn.GELU()]
        else:
            hidden_dim = round(width * expansion_factor)
            assert hidden_dim > 0
            # List elements are built left to right: expand first, then project back.
            stage = [
                masked(width, hidden_dim), nn.GELU(),
                masked(hidden_dim, width), nn.GELU(),
            ]

        if use_skip:
            layers.append(PreNormResidual(dim=width, fn=nn.Sequential(*stage)))
        else:
            layers.extend(stage)

    layers.append(nn.Linear(width, num_classes))
    return nn.Sequential(*layers)
